//! Mathematical operations and utilities.
//!
//! Provides extended GCD, modular inverse, and modular arithmetic operations.

use crate::RangeCheck;
use crate::integer::{U128MulGuarantee, u256_wide_mul, u512_safe_div_rem_by_u256};
#[allow(unused_imports)]
use crate::option::OptionTrait;
#[allow(unused_imports)]
use crate::traits::{Into, TryInto};
#[allow(unused_imports)]
use crate::zeroable::{IsZeroResult, NonZeroIntoImpl, Zeroable};

/// Runs the Extended Euclidean algorithm on `a` and `b`.
///
/// The result is a tuple `(g, s, t, sub_direction)` where `g = gcd(a, b)` and `s`, `t` are the
/// Bezout coefficients up to sign. The flag states in which direction the subtraction goes:
/// * `sub_direction == true`:  `g = s * a - t * b`
/// * `sub_direction == false`: `g = t * b - s * a`
///
/// Put differently, the Bezout coefficients are `(s, -t)` when `sub_direction` is true and
/// `(-s, t)` when it is false.
///
/// # Examples
///
/// ```
/// use core::math::egcd;
///
/// let (g, s, t, dir) = egcd::<u32>(12, 8);
/// assert!(g == 4);
/// if dir {
///     assert!(s * 12 - t * 8 == 4);
/// } else {
///     assert!(t * 8 - s * 12 == 4);
/// }
/// let (g, s, t, dir) = egcd::<i64>(-3, 15);
/// assert!(g == 3);
/// if dir {
///     assert!(s * -3 - t * 15 == 3);
/// } else {
///     assert!(t * 15 - s * -3 == 3);
/// }
/// ```
pub fn egcd<
    T,
    +Copy<T>,
    +Drop<T>,
    +Add<T>,
    +Mul<T>,
    +DivRem<T>,
    +core::num::traits::Zero<T>,
    +core::num::traits::One<T>,
    +TryInto<T, NonZero<T>>,
    +Sub<T>,
    +PartialOrd<T>,
>(
    a: NonZero<T>, b: NonZero<T>,
) -> (T, T, T, bool) {
    let (quotient, rem) = DivRem::<T>::div_rem(a.into(), b);
    let maybe_rem: Option<NonZero<T>> = rem.try_into();
    match maybe_rem {
        None => {
            // Base case: `b` divides `a`, hence `gcd(a, b) = |b|`. With `s = 0` and `t = 1` the
            // identity `s*a - t*b = -b` equals `|b|` exactly when `b` is negative, which is the
            // flag reported by `abs_and_sign`.
            let (gcd, b_is_negative) = abs_and_sign(b.into());
            let zero = core::num::traits::Zero::<T>::zero();
            let one = core::num::traits::One::<T>::one();
            (gcd, zero, one, b_is_negative)
        },
        Some(nz_rem) => {
            // Recursive case. Let `dir` stand for `+1` when `direction` is true and for `-1`
            // otherwise. The recursive call on `(b, rem)` guarantees
            //   `s*b - t*rem = dir*g`.
            // Since `a = quotient*b + rem`, we have `rem = a - quotient*b`, and therefore
            //   `s*b - t*(a - quotient*b) = dir*g`
            //   `(s + quotient*t)*b - t*a = dir*g`
            //   `t*a - (s + quotient*t)*b = (-dir)*g`.
            // So the coefficient of `a` becomes `t`, the coefficient of `b` becomes
            // `s + quotient*t`, and the direction flips.
            let (gcd, s, t, direction) = egcd(b, nz_rem);
            (gcd, t, s + quotient * t, !direction)
        },
    }
}

/// Finds the multiplicative inverse of `a` modulo `n`.
///
/// On success returns `Some(s)` with `a*s ≡ 1 (mod n)` and `1 <= s <= |n| - 1`. Returns `None`
/// when no such value exists, i.e. when `gcd(a, n) > 1` or when `|n| == 1`.
///
/// # Note
///
/// A negative `n` is handled exactly like `|n|`, since both define the same equivalence
/// classes.
///
/// # Examples
///
/// ```
/// use core::math::inv_mod;
///
/// assert!(inv_mod::<u32>(3, 7) == Some(5));
/// assert!(inv_mod::<u32>(3, 9) == None);
/// ```
pub fn inv_mod<
    T,
    +Copy<T>,
    +Drop<T>,
    +Add<T>,
    +Sub<T>,
    +Mul<T>,
    +DivRem<T>,
    +core::num::traits::Zero<T>,
    +core::num::traits::One<T>,
    +TryInto<T, NonZero<T>>,
    +PartialOrd<T>,
>(
    a: NonZero<T>, n: NonZero<T>,
) -> Option<T> {
    // Work with the positive modulus regardless of the sign of `n`.
    let (modulus, _) = abs_and_sign(n.into());
    if core::num::traits::One::<T>::is_one(@modulus) {
        return None;
    }

    let (gcd, coeff, _, sub_direction) = egcd(a, n);
    if !gcd.is_one() {
        // `a` and `n` share a non-trivial factor, so no inverse exists.
        return None;
    }

    // `1 = gcd(a, n) = +-(coeff*a - t*n)`, hence `coeff*a = +-1 (mod n)`.
    // The absolute values of Bezout coefficients are guaranteed to be `< |n|`, and with
    // `|n| > 1` and a GCD of 1 the coefficient cannot be 0.
    let (coeff_abs, coeff_is_negative) = abs_and_sign(coeff);
    let inverse = if sub_direction ^ coeff_is_negative {
        // The Bezout coefficient of `a` is positive: either it is `coeff` with `coeff > 0`, or
        // it is `-coeff` with `coeff < 0`. In both cases it equals `|coeff|`, which already lies
        // strictly between `0` and `|n|`.
        coeff_abs
    } else {
        // The Bezout coefficient of `a` is negative: either it is `-coeff` with `coeff > 0`, or
        // it is `coeff` with `coeff < 0`. In both cases it equals `-|coeff|`, and
        // `|n| - |coeff|` is the congruent value strictly between `0` and `|n|`.
        modulus - coeff_abs
    };
    Some(inverse)
}

/// Splits `value` into its magnitude and a negativity flag, i.e. returns `(|value|, value < 0)`.
///
/// The magnitude of a negative `value` is computed as `0 - value`.
fn abs_and_sign<T, +Copy<T>, +Drop<T>, +PartialOrd<T>, +Sub<T>, +core::num::traits::Zero<T>>(
    value: T,
) -> (T, bool) {
    let origin = core::num::traits::Zero::<T>::zero();
    let is_negative = value < origin;
    let magnitude = if is_negative {
        origin - value
    } else {
        value
    };
    (magnitude, is_negative)
}

/// Returns `1 / b (mod n)`, or `Err` if `b` is not invertible modulo `n`.
///
/// All `b`s will be considered not invertible for `n == 1`.
///
/// On success, the `Ok` payload holds the inverse as a `NonZero<u256>` followed by eight
/// `U128MulGuarantee`s; on failure, the `Err` payload holds two `U128MulGuarantee`s. These
/// guarantees are required for validating the calculation.
///
/// Declared `extern`, so no implementation is visible in this module; it uses the `RangeCheck`
/// implicit and is marked `nopanic`.
extern fn u256_guarantee_inv_mod_n(
    b: u256, n: NonZero<u256>,
) -> Result<
    (
        NonZero<u256>,
        U128MulGuarantee,
        U128MulGuarantee,
        U128MulGuarantee,
        U128MulGuarantee,
        U128MulGuarantee,
        U128MulGuarantee,
        U128MulGuarantee,
        U128MulGuarantee,
    ),
    (U128MulGuarantee, U128MulGuarantee),
> implicits(RangeCheck) nopanic;

/// Computes the multiplicative inverse of `a` modulo `n`.
///
/// Yields `Some(inverse)` when `a` is invertible modulo `n`, and `None` otherwise. For `n == 1`
/// every `a` is treated as non-invertible.
///
/// # Examples
///
/// ```
/// use core::math::u256_inv_mod;
///
/// let inv = u256_inv_mod(3, 17);
/// assert!(inv == Some(6));
/// ```
#[inline]
pub fn u256_inv_mod(a: u256, n: NonZero<u256>) -> Option<NonZero<u256>> {
    // Only the inverse itself is of interest here; the accompanying guarantees are ignored.
    match u256_guarantee_inv_mod_n(a, n) {
        Err(_) => None,
        Ok((inverse, _, _, _, _, _, _, _, _)) => Some(inverse),
    }
}

/// Divides `a` by `b` modulo `n`, i.e. multiplies `a` by the inverse of `b` modulo `n`.
///
/// Returns `None` when `b` has no inverse modulo `n`.
///
/// # Examples
///
/// ```
/// use core::math::u256_div_mod_n;
///
/// let result = u256_div_mod_n(17, 7, 29);
/// assert!(result == Some(19));
/// ```
pub fn u256_div_mod_n(a: u256, b: u256, n: NonZero<u256>) -> Option<u256> {
    match u256_inv_mod(b, n) {
        Some(b_inverse) => Some(u256_mul_mod_n(a, b_inverse.into(), n)),
        None => None,
    }
}

/// Multiplies `a` by `b` and reduces the product modulo `n`.
///
/// The product is formed with `u256_wide_mul` before being divided by `n`; the remainder of
/// that division is the result.
///
/// # Examples
///
/// ```
/// use core::math::u256_mul_mod_n;
///
/// let result = u256_mul_mod_n(17, 23, 29);
/// assert!(result == 14);
/// ```
pub fn u256_mul_mod_n(a: u256, b: u256, n: NonZero<u256>) -> u256 {
    let wide_product = u256_wide_mul(a, b);
    // The quotient is irrelevant for modular multiplication; keep only the remainder.
    let (_, remainder) = u512_safe_div_rem_by_u256(wide_product, n);
    remainder
}

/// A trait for types that have a multiplicative identity element.
///
/// The trait is private to this module and its methods take `self` by value.
trait Oneable<T> {
    /// Returns the multiplicative identity element of Self, 1.
    #[must_use]
    fn one() -> T;

    /// Returns whether self is equal to 1, the multiplicative identity element.
    #[must_use]
    fn is_one(self: T) -> bool;

    /// Returns whether self is not equal to 1, the multiplicative identity element.
    #[must_use]
    fn is_non_one(self: T) -> bool;
}

/// Holds the generic `Oneable` implementation that is backed by `crate::num::traits::One`.
pub(crate) mod one_based {
    /// Implements `Oneable<T>` for any droppable, copyable `T` that has a `One` implementation,
    /// by forwarding every method to that implementation (`OneImpl`).
    pub(crate) impl OneableImpl<
        T, impl OneImpl: crate::num::traits::One<T>, +Drop<T>, +Copy<T>,
    > of super::Oneable<T> {
        /// Returns `OneImpl::one()`.
        fn one() -> T {
            OneImpl::one()
        }

        /// Forwards to `OneImpl::is_one`, passing a snapshot of the by-value `self`.
        #[inline]
        fn is_one(self: T) -> bool {
            OneImpl::is_one(@self)
        }

        /// Forwards to `OneImpl::is_non_one`, passing a snapshot of the by-value `self`.
        #[inline]
        fn is_non_one(self: T) -> bool {
            OneImpl::is_non_one(@self)
        }
    }
}

// Module-private `Oneable` implementations for the unsigned integer types `u8` through `u256`,
// each an alias of the generic `one_based::OneableImpl` instantiated for that type.
impl U8Oneable = one_based::OneableImpl<u8>;
impl U16Oneable = one_based::OneableImpl<u16>;
impl U32Oneable = one_based::OneableImpl<u32>;
impl U64Oneable = one_based::OneableImpl<u64>;
impl U128Oneable = one_based::OneableImpl<u128>;
impl U256Oneable = one_based::OneableImpl<u256>;
